{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compound representation learning and property prediction\n",
    "\n",
    "In this tutorial, we will go through how to run a Graph Neural Network (GNN) model for compound property prediction. In particular, we will demonstrate how to pretrain the model and finetune it on downstream tasks. If you are interested in more details, please refer to the README for \"[info graph](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/info_graph)\" and \"[pretrained GNN](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/pretrain_gnns)\".\n",
    "\n",
    "# Part I: Pretraining\n",
    "\n",
    "In this part, we will show how to pretrain a compound GNN model. The pretraining strategies used here are adapted from the pretrain GNNs work and include attribute masking, context prediction and supervised pretraining.\n",
    "\n",
    "Visit `pretrain_attrmask.py`, `pretrain_contextpred.py` and `pretrain_supervised.py` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "os.chdir(\"../apps/pretrained_compound/pretrain_gnns\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Pahelix framework is built upon PaddlePaddle, a deep learning framework."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] 2020-12-18 19:57:50,304 [mp_reader.py:   23]:\tujson not install, fail back to use json instead\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.incubate.fleet.collective import fleet\n",
    "from pahelix.datasets import load_zinc_dataset\n",
    "from pahelix.featurizers import PreGNNAttrMaskFeaturizer\n",
    "from pahelix.utils.compound_tools import CompoundConstants\n",
    "from pahelix.model_zoo import PreGNNAttrmaskModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    }
   ],
   "source": [
    "# switch to paddle static graph mode.\n",
    "paddle.enable_static()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the static graph\n",
    "We build the static graph with Paddle `Program` and `Executor`. Here, we use `model_config` to hold the model configurations. `PreGNNAttrmaskModel` is an unsupervised pretraining model which randomly masks the atom types of a subset of nodes and then tries to predict the masked atom types. Meanwhile, we use the Adam optimizer and set the learning rate to 0.001.\n",
    "\n",
    "To train on a GPU, uncomment the line `exe = fluid.Executor(fluid.CUDAPlace(0))`; `fluid.CPUPlace()` is for CPU training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:78\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:98\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/networks/gnn_block.py:194\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var mean_0.tmp_0 : LOD_TENSOR.shape(1,).dtype(FP32).stop_gradient(False)\n"
     ]
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,  # dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "}\n",
    "train_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(train_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = PreGNNAttrmaskModel(model_config=model_config)\n",
    "        model.forward()\n",
    "        opt = fluid.optimizer.Adam(learning_rate=0.001)\n",
    "        opt.minimize(model.loss)\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)\n",
    "print(model.loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset loading and feature extraction\n",
    "### Download the dataset using wget\n",
    "First, we need to download a small dataset for this demo. If you do not have `wget` on your machine, you can also\n",
    "copy the URL below into your web browser to download the data. In that case, remember to move the downloaded file manually to the\n",
    "path \"../apps/pretrained_compound/pretrain_gnns/\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-18 19:58:05--  https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\n",
      "Connecting to 172.19.56.199:3128... connected.\n",
      "WARNING: certificate common name “*.bcebos.com” doesn’t match requested host name “baidu-nlp.bj.bcebos.com”.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 609563 (595K) [application/gzip]\n",
      "Saving to: “PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz”\n",
      "\n",
      "100%[======================================>] 609,563      231K/s   in 2.6s    \n",
      "\n",
      "2020-12-18 19:58:12 (231 KB/s) - “PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz” saved [609563/609563]\n",
      "\n",
      "tox21  zinc_standard_agent\n"
     ]
    }
   ],
   "source": [
    "### Download a toy dataset for demonstration:\n",
    "!wget \"https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\" --no-check-certificate\n",
    "!tar -zxf \"PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\"\n",
    "!ls \"./chem_dataset_small\"\n",
    "### Optionally, download the full dataset instead:\n",
    "# !wget \"http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip\" --no-check-certificate\n",
    "# !unzip \"chem_dataset.zip\"\n",
    "# !ls \"./chem_dataset\""
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "`PreGNNAttrMaskFeaturizer` is used along with `PreGNNAttrmaskModel`. It inherits from the base class `Featurizer`, which handles feature extraction. A `Featurizer` has two methods: `gen_features` converts a single raw SMILES string into a single graph data item, and `collate_fn` aggregates a list of graph data items into one batch.\n",
    "The ZINC dataset is used as the pretraining dataset."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset num: 1000\n"
     ]
    }
   ],
   "source": [
    "featurizer = PreGNNAttrMaskFeaturizer(\n",
    "        model.graph_wrapper, \n",
    "        atom_type_num=len(CompoundConstants.atom_num_list),\n",
    "        mask_ratio=0.15)\n",
    "### Load only the first 1000 compounds of the toy dataset to speed things up\n",
    "dataset = load_zinc_dataset(\"./chem_dataset_small/zinc_standard_agent/raw\", featurizer=featurizer)\n",
    "dataset = dataset[:1000]\n",
    "### Load the full dataset:\n",
    "# dataset = load_zinc_dataset(\"./chem_dataset/zinc_standard_agent/raw\", featurizer=featurizer)\n",
    "print(\"dataset num: %s\" % (len(dataset)))"
   ]
  },
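  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the masking step more concrete, here is a minimal NumPy sketch of attribute masking. It is independent of `PreGNNAttrMaskFeaturizer` and is not its actual implementation: the function name `mask_atom_types`, the mask token id `119` and the rule of masking at least one atom are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mask_atom_types(atom_types, mask_ratio, mask_token, rng):\n",
    "    '''Replace a random subset of atom types by `mask_token`.\n",
    "\n",
    "    Returns the masked copy, the masked positions and the original\n",
    "    atom types at those positions (the prediction targets).\n",
    "    '''\n",
    "    atom_types = np.asarray(atom_types)\n",
    "    # Assumption: always mask at least one atom.\n",
    "    num_masked = max(1, int(round(len(atom_types) * mask_ratio)))\n",
    "    masked_idx = np.sort(\n",
    "        rng.choice(len(atom_types), size=num_masked, replace=False))\n",
    "    masked = atom_types.copy()\n",
    "    masked[masked_idx] = mask_token\n",
    "    return masked, masked_idx, atom_types[masked_idx]\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Toy atomic numbers of a 20-atom molecule (hypothetical data).\n",
    "atom_types = [6, 6, 8, 7, 6, 6, 6, 8, 6, 6, 7, 6, 6, 6, 6, 8, 6, 6, 6, 6]\n",
    "masked, idx, targets = mask_atom_types(\n",
    "    atom_types, mask_ratio=0.15, mask_token=119, rng=rng)\n",
    "print(idx, targets)\n",
    "```\n",
    "\n",
    "The model is then trained to recover `targets` at the positions `idx` from the masked input."
   ]
  },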
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training\n",
    "Now we train the attrmask model for 2 epochs for demonstration purposes. Data loading is accelerated with 4 worker processes (`num_workers=4`). The pretrained model is then saved to \"./model/pretrain_attrmask\", which will serve as the initial model for the downstream tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 train/loss:0.7213694\n",
      "epoch:1 train/loss:0.70324224\n"
     ]
    }
   ],
   "source": [
    "def train(exe, train_prog, model, dataset, featurizer):\n",
    "    data_gen = dataset.iter_batch(\n",
    "            batch_size=256, num_workers=4, shuffle=True, collate_fn=featurizer.collate_fn)\n",
    "    list_loss = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        train_loss, = exe.run(train_prog, \n",
    "                feed=feed_dict, fetch_list=[model.loss], return_numpy=False)\n",
    "        list_loss.append(np.array(train_loss).mean())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "for epoch_id in range(2):\n",
    "    train_loss = train(exe, train_prog, model, dataset, featurizer)\n",
    "    print(\"epoch:%d train/loss:%s\" % (epoch_id, train_loss))\n",
    "fluid.io.save_params(exe, './model/pretrain_attrmask', train_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above cover pretraining; you can adjust them as needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part II: Downstream finetuning\n",
    "Below we show how to finetune the pretrained model on downstream tasks.\n",
    "\n",
    "Visit `finetune.py` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pahelix.utils.paddle_utils import load_partial_params\n",
    "from pahelix.utils.splitters import \\\n",
    "    RandomSplitter, IndexSplitter, ScaffoldSplitter, RandomScaffoldSplitter\n",
    "from pahelix.datasets import *\n",
    "\n",
    "from model import DownstreamModel\n",
    "from featurizer import DownstreamFeaturizer\n",
    "from utils import calc_rocauc_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The downstream datasets are usually small and cover different tasks. For example, the BBBP dataset is used to predict blood-brain barrier permeability, and the Tox21 dataset is used to predict the toxicity of compounds. Here we use the Tox21 dataset for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']\n"
     ]
    }
   ],
   "source": [
    "task_names = get_default_tox21_task_names()\n",
    "print(task_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the static graph\n",
    "We build the static graph with Paddle `Program` and `Executor`. Here, we use `model_config` to hold the model configurations. Note that the model architecture settings must match those of the pretraining model, otherwise loading the pretrained parameters will fail. `DownstreamModel` is a supervised GNN model which predicts the tasks listed in `task_names`. Meanwhile, we use the Adam optimizer and set the learning rate to 0.001.\n",
    "\n",
    "We use `train_prog` and `test_prog` to hold the static graphs for training and testing. They share the same network architecture, but some operators, such as `Dropout` and `BatchNorm`, behave differently at test time.\n",
    "\n",
    "To train on a GPU, uncomment the line `exe = fluid.Executor(fluid.CUDAPlace(0))`; `fluid.CPUPlace()` is for CPU training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/apps/pretrained_compound/pretrain_gnns/model.py:90\n",
      "The behavior of expression A * B has been unified with elementwise_mul(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_mul(X, Y, axis=0) instead of A * B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,  # dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "    \"num_tasks\": len(task_names), # number of targets to predict for the downstream task.\n",
    "}\n",
    "train_prog = fluid.Program()\n",
    "test_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(train_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.train()\n",
    "        opt = fluid.optimizer.Adam(learning_rate=0.001)\n",
    "        opt.minimize(model.loss)\n",
    "with fluid.program_guard(test_prog, fluid.Program()):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.train(is_test=True)\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load pretrained models\n",
    "Load the model saved in the pretraining phase. Here we load the model \"pretrain_attrmask\" as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load parameters from ./model/pretrain_attrmask.\n"
     ]
    }
   ],
   "source": [
    "load_partial_params(exe, './model/pretrain_attrmask', train_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset loading and feature extraction\n",
    "`DownstreamFeaturizer` is used along with `DownstreamModel`. It inherits from the base class `Featurizer`, which handles feature extraction. A `Featurizer` has two methods: `gen_features` converts a single raw SMILES string into a single graph data item, and `collate_fn` aggregates a list of graph data items into one batch.\n",
    "\n",
    "The Tox21 dataset is used as the downstream dataset and we use `ScaffoldSplitter` to split it into train/valid/test sets. `ScaffoldSplitter` first orders the compounds according to their Bemis-Murcko scaffolds, then takes the first `frac_train` proportion as the train set, the next `frac_valid` proportion as the valid set and the rest as the test set. Because the subsets then contain structurally different compounds, `ScaffoldSplitter` gives a better estimate of how the model generalizes to out-of-distribution samples. Note that other splitters such as `RandomSplitter`, `RandomScaffoldSplitter` and `IndexSplitter` are also available."
   ]
  },
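  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind a scaffold split can be sketched in a few lines of plain Python. This is an illustration only, not the code of `ScaffoldSplitter`: the function name `scaffold_split_indices` is hypothetical, the scaffold strings are assumed to be precomputed (in practice they would be derived from the SMILES strings), and ordering the scaffold groups from largest to smallest is an assumption based on a common convention.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def scaffold_split_indices(scaffolds, frac_train=0.8, frac_valid=0.1):\n",
    "    '''Split sample indices so that each scaffold group stays in one subset.\n",
    "\n",
    "    `scaffolds` holds one precomputed scaffold string per compound.\n",
    "    '''\n",
    "    groups = defaultdict(list)\n",
    "    for idx, scaffold in enumerate(scaffolds):\n",
    "        groups[scaffold].append(idx)\n",
    "    # Assumption: largest groups first, ties broken by first index.\n",
    "    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))\n",
    "\n",
    "    n = len(scaffolds)\n",
    "    train_cutoff = frac_train * n\n",
    "    valid_cutoff = (frac_train + frac_valid) * n\n",
    "    train, valid, test = [], [], []\n",
    "    for group in ordered:\n",
    "        if len(train) + len(group) <= train_cutoff:\n",
    "            train.extend(group)\n",
    "        elif len(train) + len(valid) + len(group) <= valid_cutoff:\n",
    "            valid.extend(group)\n",
    "        else:\n",
    "            test.extend(group)\n",
    "    return train, valid, test\n",
    "\n",
    "\n",
    "# Toy scaffold labels for 10 compounds (hypothetical data).\n",
    "toy_scaffolds = ['A'] * 6 + ['B'] * 2 + ['C', 'D']\n",
    "train_idx, valid_idx, test_idx = scaffold_split_indices(toy_scaffolds)\n",
    "print(train_idx, valid_idx, test_idx)\n",
    "```\n",
    "\n",
    "Since whole scaffold groups are assigned to a single subset, no scaffold is shared between the train, valid and test sets."
   ]
  },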
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [20:05:16] WARNING: not removing hydrogen atom without neighbors\n",
      "RDKit WARNING: [20:05:32] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Valid/Test num: 6264/783/784\n"
     ]
    }
   ],
   "source": [
    "featurizer = DownstreamFeaturizer(model.graph_wrapper)\n",
    "### Load the toy dataset:\n",
    "dataset = load_tox21_dataset(\n",
    "        \"./chem_dataset_small/tox21/raw\", task_names, featurizer=featurizer)\n",
    "### Load the full dataset:\n",
    "# dataset = load_tox21_dataset(\n",
    "#         \"./chem_dataset/tox21/raw\", task_names, featurizer=featurizer)\n",
    "\n",
    "# splitter = RandomSplitter()\n",
    "splitter = ScaffoldSplitter()\n",
    "train_dataset, valid_dataset, test_dataset = splitter.split(\n",
    "        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)\n",
    "print(\"Train/Valid/Test num: %s/%s/%s\" % (\n",
    "        len(train_dataset), len(valid_dataset), len(test_dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training\n",
    "Now we finetune the downstream model for 4 epochs for demonstration purposes. Since a downstream dataset usually contains more than one sub-task (Tox21 has 12), the performance of the model is evaluated by the average ROC-AUC over all sub-tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:0 train/loss:0.49475822\n",
      "epoch:0 val/auc:0.6658935384290235\n",
      "epoch:0 test/auc:0.6521651483374383\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:1 train/loss:0.25100735\n",
      "epoch:1 val/auc:0.6873462784930324\n",
      "epoch:1 test/auc:0.6848870489412003\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:2 train/loss:0.22103228\n",
      "epoch:2 val/auc:0.6841346893995142\n",
      "epoch:2 test/auc:0.6818656406099143\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:3 train/loss:0.21569714\n",
      "epoch:3 val/auc:0.6623265146967362\n",
      "epoch:3 test/auc:0.6324247392008124\n"
     ]
    }
   ],
   "source": [
    "def train(exe, train_prog, model, train_dataset, featurizer):\n",
    "    data_gen = train_dataset.iter_batch(\n",
    "        batch_size=64, num_workers=4, shuffle=True, collate_fn=featurizer.collate_fn)\n",
    "    list_loss = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        train_loss, = exe.run(train_prog, feed=feed_dict, fetch_list=[model.loss], return_numpy=False)\n",
    "        list_loss.append(np.array(train_loss).mean())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "def evaluate(exe, test_prog, model, test_dataset, featurizer):\n",
    "    \"\"\"\n",
    "    Some labels in the dataset are blank, so a `valid` tensor is used\n",
    "    to mask out these blank labels in both the training and evaluation phases.\n",
    "\n",
    "    Returns:\n",
    "        the average ROC-AUC over all sub-tasks.\n",
    "    \"\"\"\n",
    "    data_gen = test_dataset.iter_batch(\n",
    "        batch_size=64, num_workers=4, shuffle=False, collate_fn=featurizer.collate_fn)\n",
    "    total_pred = []\n",
    "    total_label = []\n",
    "    total_valid = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        pred, = exe.run(test_prog, feed=feed_dict, fetch_list=[model.pred], return_numpy=False)\n",
    "        total_pred.append(np.array(pred))\n",
    "        total_label.append(feed_dict['finetune_label'])\n",
    "        total_valid.append(feed_dict['valid'])\n",
    "    total_pred = np.concatenate(total_pred, 0)\n",
    "    total_label = np.concatenate(total_label, 0)\n",
    "    total_valid = np.concatenate(total_valid, 0)\n",
    "    return calc_rocauc_score(total_label, total_pred, total_valid)\n",
    "\n",
    "for epoch_id in range(4):\n",
    "    train_loss = train(exe, train_prog, model, train_dataset, featurizer)\n",
    "    val_auc = evaluate(exe, test_prog, model, valid_dataset, featurizer)\n",
    "    test_auc = evaluate(exe, test_prog, model, test_dataset, featurizer)\n",
    "    print(\"epoch:%s train/loss:%s\" % (epoch_id, train_loss))\n",
    "    print(\"epoch:%s val/auc:%s\" % (epoch_id, val_auc))\n",
    "    print(\"epoch:%s test/auc:%s\" % (epoch_id, test_auc))\n",
    "fluid.io.save_params(exe, './model/tox21', train_prog)"
   ]
  },
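  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `evaluate` function above delegates the metric to `calc_rocauc_score`, whose code is not shown in this tutorial. As a rough sketch of how a masked, task-averaged ROC-AUC can be computed, the NumPy example below may help. The names `binary_auc` and `masked_mean_auc` are hypothetical, and skipping sub-tasks that lack one of the two classes is an assumption; the real helper may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def binary_auc(labels, scores):\n",
    "    '''ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.'''\n",
    "    pos = scores[labels == 1]\n",
    "    neg = scores[labels == 0]\n",
    "    diff = pos[:, None] - neg[None, :]\n",
    "    # Ties between a positive and a negative score count as half correct.\n",
    "    return float(((diff > 0) + 0.5 * (diff == 0)).mean())\n",
    "\n",
    "\n",
    "def masked_mean_auc(labels, preds, valid):\n",
    "    '''Average the per-task ROC-AUC, using only entries where `valid` is 1.'''\n",
    "    aucs = []\n",
    "    for task in range(labels.shape[1]):\n",
    "        mask = valid[:, task] > 0.5\n",
    "        y, p = labels[mask, task], preds[mask, task]\n",
    "        # Assumption: tasks without both classes are skipped (AUC undefined).\n",
    "        if (y == 1).any() and (y == 0).any():\n",
    "            aucs.append(binary_auc(y, p))\n",
    "    return float(np.mean(aucs))\n",
    "\n",
    "\n",
    "# Toy data: 4 compounds, 2 sub-tasks (hypothetical values).\n",
    "labels = np.array([[1, 0], [0, 1], [1, 0], [0, 0]])\n",
    "preds = np.array([[0.9, 0.2], [0.1, 0.8], [0.4, 0.3], [0.6, 0.5]])\n",
    "# The label of the last compound is blank for the second sub-task.\n",
    "valid = np.array([[1, 1], [1, 1], [1, 1], [1, 0]])\n",
    "print(masked_mean_auc(labels, preds, valid))  # 0.875\n",
    "```\n",
    "\n",
    "In this toy case the first sub-task scores 0.75 and the second 1.0 once the blank label is masked out, which gives the average of 0.875."
   ]
  },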
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part III: Downstream Inference\n",
    "In this part, we briefly show how to use the finetuned downstream model to run inference on a given SMILES string."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the static graph\n",
    "This step is basically the same as in Part II, except that only an inference program is built."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:78\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:98\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/networks/gnn_block.py:194\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,  # dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "    \"num_tasks\": len(task_names), # number of targets to predict for the downstream task.\n",
    "}\n",
    "inference_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(inference_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.inference()\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load finetuned models\n",
    "Load the finetuned model saved in Part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load parameters from ./model/tox21.\n"
     ]
    }
   ],
   "source": [
    "load_partial_params(exe, './model/tox21', inference_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Inference\n",
    "Run inference on the given SMILES string. We directly call `gen_features` and `collate_fn` of the `DownstreamFeaturizer` to convert the raw SMILES string into the model input.\n",
    "\n",
    "Using the Tox21 dataset as an example, our finetuned downstream model makes predictions for its 12 sub-tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SMILES:O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\n",
      "Predictions:\n",
      "  NR-AR:\t0.07395526\n",
      "  NR-AR-LBD:\t0.060779173\n",
      "  NR-AhR:\t0.32708064\n",
      "  NR-Aromatase:\t0.105057806\n",
      "  NR-ER:\t0.23960893\n",
      "  NR-ER-LBD:\t0.10899512\n",
      "  NR-PPAR-gamma:\t0.08047223\n",
      "  SR-ARE:\t0.3242342\n",
      "  SR-ATAD5:\t0.103697956\n",
      "  SR-HSE:\t0.08612242\n",
      "  SR-MMP:\t0.321919\n",
      "  SR-p53:\t0.15215957\n"
     ]
    }
   ],
   "source": [
    "SMILES = \"O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\"\n",
    "featurizer = DownstreamFeaturizer(model.graph_wrapper, is_inference=True)\n",
    "feed_dict = featurizer.collate_fn([featurizer.gen_features({'smiles': SMILES})])\n",
    "pred, = exe.run(inference_prog, feed=feed_dict, fetch_list=[model.pred], return_numpy=False)\n",
    "probs = np.array(pred)[0]\n",
    "print('SMILES:%s' % SMILES)\n",
    "print('Predictions:')\n",
    "for name, prob in zip(task_names, probs):\n",
    "    print(\"  %s:\\t%s\" % (name, prob))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}